TMLE(Targeted Maximum Likelihood Estimation)に基づいた条件付き平均処置効果の推論
概要
TMLE(Targeted Maximum Likelihood Estimation)に対する理解を深めることを目的として、TMLEに基づいた条件付き平均処置効果(CATE)の推論を実施してみます。
スクリプトはこちらのノートブックファイルをベースに実施しました。
目次
1.最初に
本エントリーでは、「2.TMLE(Targeted Maximum Likelihood Estimation)とは」にて下記の資料を元にTMLEの概要について確認した後に、
- Targeted Maximum Likelihood Estimation: A Gentle Introduction
- Targeted maximum likelihood estimation for a binary treatment: A tutorial
「3.causalmlで推論してみる」で「causalml」のサンプルスクリプトを参考に、実際にTMLEに基づいた条件付き平均処置効果(CATE)を推論してみます。
2.TMLE(Targeted Maximum Likelihood Estimation)とは
まずはTMLEについてざっくり確認します。
「TMLEとはなんなのか」という点について、こちらから引用します。
TMLE is a general methodology that can be applied to estimation of many types of causal effect parameters, including but not limited to those involving point treatment effects, survival analysis, longitudinal data analysis, and genomics data.
参照:4 Discussion
処置効果を推論する様々なタスクに適用できる手法、といったところでしょうか。
計算フローについて、Targeted maximum likelihood estimation for a binary
treatment: A tutorialを参考に概要レベルで記述します。
Step1.目的変数を推論するモデルを学習
介入変数(A)、共変量(W)を利用して目的変数(Y)を推論するモデルを学習します。
介入有無に応じて、それぞれ別のモデルを利用することも可能です。
下記は本論文中で紹介されている計算式の例です。
(絶対に下記の通り計算しないといけない、という訳ではなくあくまでも一例としての計算式です)
Step2.傾向スコアの計算
共変量(W)に基づいて傾向スコアを計算します。
下記は本論文中で紹介されている計算式の例です。
ロジスティック回帰で求めた値を使って、傾向スコアを求める関数を計算しています。
Step3.4つの関数の計算
「Step2で求めた傾向スコアを使って、Step1で学習したモデルを補強する」ために、4つの関数を計算します。
(下記で言う所の「clever covariate」、εでそれぞれ介入/対照の2つずつ)
まずは「Step2」で求めた傾向スコアと介入変数(A)から下記のような関数を計算します。
(IPTW法(Inverse Probability of Treatment Weighting)と似ています)
(介入変数(A)は0か1の値が入っていることを仮定)
そして、ここまでの計算結果を利用して「ε0,ε1」を最尤推定に基づいて計算します。
(下記は本論文中で紹介されている計算式の例です)
この「ε0,ε1」は直感的には「Step1で求めたモデルの推論結果と目的変数の差分を補正する値」と言えます。
なので、差分が大きい時には大きな値を、差分が小さい時には0に近い値になります。
(ex.もし「Step1」で学習したモデルがとても精度良く目的変数を推論できている場合は、「ε0,ε1」は0に近い値ばかりになります)
Step4.モデルを更新
ここまでの計算結果を利用し、「Step1」のモデルを更新します。
本論文中では下記のように紹介されていますが、「Step3」で記述している通り「Step1」のモデルの精度が高いようならモデルの更新前後で殆ど変わらない(εが0に近いので)、ということがわかります。
Step5.CATEを計算
「Step4」で更新したモデルを利用し、下記の通りCATEが計算できます。
(Marginal Odds Ratio(MOR)も計算できます)
また、標準誤差も計算できるので、信頼区間を計算することもできます。
3.causalmlで推論してみる
TMLEの概要について確認したので、causalmlを使って実際にTMLEに基づいたCATEを計算してみます。
まずは今回利用するデータを準備します。
from causalml.dataset import synthetic_data from sklearn.model_selection import train_test_split y, X, treatment, tau, b, e = synthetic_data(mode=1, n=1000000, p=10, sigma=5.) X_train, X_test, y_train, y_test, e_train, e_test, treatment_train, treatment_test, tau_train, tau_test, b_train, b_test = train_test_split(X, y, e, treatment, tau, b, test_size=0.5, random_state=42)
上記で用意したデータに対して、下記のようなスクリプトで推論できます。
指定したパラメータに応じた信頼区間の下限/上限も計算できます。
# 1.交差検証等のパラメータを指定 from sklearn.model_selection import KFold n_fold = 5 kf = KFold(n_splits=n_fold) # 2.傾向スコアの計算 from causalml.propensity import compute_propensity_score p, _ = compute_propensity_score(X=X_train, treatment=treatment_train, # 学習用データで学習して X_pred=X_test, treatment_pred=treatment_test, # テスト用データの推論値を返却する cv=kf,calibrate_p=True ) # 3.base-learnerと学習基の初期化 from lightgbm import LGBMRegressor from causalml.inference.meta import TMLELearner base_learner = LGBMRegressor(num_leaves=64, learning_rate=.05, n_estimators=300) tmle = TMLELearner(learner= base_learner, ate_alpha=.05, cv=kf, calibrate_propensity=True) # 4.CATEの推論 cate_xlearner, cate_xlearner_lb, cate_xlearner_ub = tmle.estimate_ate(X=X_test, p=p, treatment=treatment_test, y=y_test) print(cate_xlearner, cate_xlearner_lb, cate_xlearner_ub)
[0.51930395] [0.4825293] [0.5560786]
また、下記のようにセグメントとしてデータを分割してCATEを計算することもできます。
下記では「tau_test」の値に応じて、「n_segment」変数に指定したセグメント数に分割してCATEを計算しています。
分析/検証のいずれの用途においてもお世話になりそうです。
import pandas as pd n_segment = 5 cate_xlearner_seg, cate_xlearner_lb_seg, cate_xlearner_ub_seg = tmle.estimate_ate(X=X_test, p=e_test, treatment=treatment_test, y=y_test, segment=pd.qcut(tau_test, n_segment, labels=False) ) print(cate_xlearner_seg, cate_xlearner_lb_seg, cate_xlearner_ub_seg)
[[0.1485507 0.38580702 0.52353701 0.70505099 0.82688852]] [[0.05304797 0.31081193 0.45404529 0.61390115 0.73205315]] [[0.24405343 0.46080211 0.59302873 0.79620083 0.92172388]]
4.最後に
TMLEについて、計算フローとcausalmlでの実行方法をざっくり確認しました。
自分で実装するとなると大変そうですが、パッケージ側で実装してくれているのはありがたいですね。